using LinearAlgebra
using SparseArrays

using CSV
using DataFrames
using KrylovKit

println(); println("Starting drive..."); println()

# Delay times tau taken from the experiment: 17 uniformly spaced points from 5.0e-6 to 0.005.
tau_arr = [
    5.0e-6, 0.00031718750000000003, 0.000629375, 0.0009415625,
    0.0012537499999999999, 0.0015659375, 0.001878125, 0.0021903125,
    0.0025025, 0.0028146875, 0.0031268750000000003, 0.0034390625,
    0.0037512500000000002, 0.0040634375, 0.004375625, 0.0046878125,
    0.005,
]

"""
    ggms(d)

Return a basis of `d^2` sparse `d×d` matrices: the generalized Gell-Mann matrices plus a
scaled identity. Each matrix has unit Hilbert-Schmidt (Frobenius) norm. The result is a
`Vector{Any}` ordered as follows (`E_ab` has a single 1 at row `a`, column `b`):

1. "R" operators: `(E_ab + E_ba)/√2` for every pair `a < b` (`a` outer loop, `b` inner),
   real-valued;
2. "L" operators: for every pair `b < a` (`a` outer loop, `b` inner), `-i/√2` at
   `(b, a)` and `+i/√2` at `(a, b)`, complex-valued;
3. `d-1` diagonal (Cartan) operators: for `m = 1:d-1`, the last `m` diagonal entries are
   `1/√(m(m+1))`, the entry just before them is `-m/√(m(m+1))`, all earlier entries are
   stored zeros;
4. the identity scaled by `1/√d`.
"""
function ggms(d::Int64)::Vector{}
    basis = Any[]

    # Real symmetric off-diagonal generators.
    for a in 1:d, b in (a + 1):d
        push!(basis, sparse([b, a], [a, b], [1.0, 1.0], d, d) / sqrt(2))
    end

    # Imaginary antisymmetric off-diagonal generators (entries -i at (b, a), +i at (a, b)).
    for a in 1:d, b in 1:(a - 1)
        push!(basis, sparse([b, a], [a, b], [1.0, -1.0], d, d) * (-1im) / sqrt(2))
    end

    # Diagonal (Cartan) generators. The coefficient pattern [1 … 1, -m, 0 … 0] is laid
    # onto the diagonal back-to-front, so the run of ones ends at index d.
    for m in 1:(d - 1)
        coeffs = sqrt(1 / (m * (m + 1))) .* [ones(m); -m; zeros(d - m - 1)]
        push!(basis, sparse(1:d, 1:d, reverse(coeffs)))
    end

    # Normalized identity completes the d^2-element basis.
    push!(basis, sparse(I, d, d) / sqrt(d))

    return basis
end


"""
    evolution_superoperator(GG, H, lindblad_ops)

Return the Lindblad generator expressed as a real `length(GG) × length(GG)` matrix in the
operator basis `GG`.

Column `j` is built from `𝓛(GG[j])`, where
`𝓛(ρ) = -i(Hρ - ρH) + Σₖ (Lₖ ρ Lₖ' - ½ Lₖ'Lₖ ρ - ½ ρ Lₖ'Lₖ)` and `Lₖ` runs over
`lindblad_ops`. Entry `[i, j]` is `real(tr(GG[i] * 𝓛(GG[j])))`. The imaginary part of
the trace is discarded; that is exact when `GG` holds Hermitian matrices (e.g. from
`ggms`). NOTE(review): the result is the generator in that basis only if `GG` is also
orthonormal — confirm for other bases.
"""
function evolution_superoperator(
    GG::Vector,
    H::Matrix{T},
    lindblad_ops::Vector{Matrix{U}}
) where T <: Number where U <: Number
    dims = size(H)

    # Scratch buffers reused for every matrix product.
    buf_a = zeros(ComplexF64, dims...)
    buf_b = zeros(ComplexF64, dims...)
    buf_c = zeros(ComplexF64, dims...)

    # Per jump operator: (L, L', L'L), computed once up front.
    jumps = [(L, collect(L'), L'L) for L in lindblad_ops]

    hamiltonian_part = zeros(ComplexF64, dims...)
    dissipative_part = zeros(ComplexF64, dims...)
    generator = zeros(ComplexF64, dims...)

    n_basis = length(GG)
    evol = zeros(n_basis, n_basis)
    for jj in 1:n_basis
        g = GG[jj]

        # Coherent part: -i[H, g].
        mul!(buf_a, H, g)
        mul!(buf_b, g, H)
        hamiltonian_part .= (-im) .* (buf_a .- buf_b)

        # Dissipative part, accumulated one jump operator at a time.
        fill!(dissipative_part, 0)
        for (L, Ld, LdL) in jumps
            mul!(buf_a, L, g)
            mul!(buf_c, buf_a, Ld)
            mul!(buf_a, LdL, g)
            mul!(buf_b, g, LdL)
            dissipative_part .+= (-buf_a ./ 2 .- buf_b ./ 2 .+ buf_c)
        end

        generator .= hamiltonian_part .+ dissipative_part

        # Project 𝓛(GG[jj]) onto every basis element.
        for ii in 1:n_basis
            mul!(buf_a, GG[ii], generator)
            evol[ii, jj] = real(tr(buf_a))
        end
    end

    return evol
end


# Load the N0 and N1M potentials from the `.h5` files in potential_dir; the DVR grid step
# sizes are taken from the N0 potential.
potential_array_N0 = potential_array_from_h5(string(potential_dir, "/N0$aberration_type.h5"));
potential_array_N1M = potential_array_from_h5(string(potential_dir, "/N1M$aberration_type.h5"));

step_size_arr = compute_step_size_arr(potential_array_N0)

# Per-tau output buffers, filled by the evolution loop further down.
p0_arr = zeros(length(tau_arr))
p1_arr = zeros(length(tau_arr))

trap_depth_factor = trap_depth/4
@show trap_depth_factor
potential_N0 = potential_gridvalues(potential_array_N0) * trap_depth_factor
potential_N1M = potential_gridvalues(potential_array_N1M) * trap_depth_factor

# DVR Hamiltonians divided by 2πħ (= h), presumably giving frequency units — confirm
# against the units of dvr_hamiltonian and hbar.
@time hamiltonian_N0 = dvr_hamiltonian(m_NaCs, step_size_arr, potential_N0) / (2pi*hbar);
hamiltonian_N1M = dvr_hamiltonian(m_NaCs, step_size_arr, potential_N1M) / (2pi*hbar);

# Lowest num_lvls eigenpairs (:SR = smallest real part) of each Hamiltonian; for small
# grids, eigen(Matrix(H)) is a dense alternative. Unconverged eigenpairs would feed into
# every Hamiltonian built below, so report them instead of continuing silently.
println("Krylov eigensolver...")
@time energies_N0, states_N0, krylov_info = eigsolve(hamiltonian_N0, num_lvls, :SR, Float64, tol=krylov_tol, krylovdim=2num_lvls);
if krylov_info.converged < num_lvls
    @warn "eigsolve(hamiltonian_N0): only $(krylov_info.converged) of $num_lvls eigenpairs converged; consider a larger krylovdim or maxiter"
end
energies_N1M, states_N1M, krylov_info = eigsolve(hamiltonian_N1M, num_lvls, :SR, Float64, tol=krylov_tol, krylovdim=2num_lvls);
if krylov_info.converged < num_lvls
    @warn "eigsolve(hamiltonian_N1M): only $(krylov_info.converged) of $num_lvls eigenpairs converged; consider a larger krylovdim or maxiter"
end

# Keep the first num_lvls eigenvectors as matrix columns.
states_N0 = hcat(states_N0[1:num_lvls]...);
states_N1M = hcat(states_N1M[1:num_lvls]...);

# Eigenvectors come back with an arbitrary sign/phase. Rescale each N1M eigenvector by the
# phase of its overlap with the matching N0 eigenvector so that overlap is non-negative.
# Only the diagonal overlaps are needed. An exactly-zero overlap leaves the vector
# unchanged, since sign(0) == 0 would otherwise zero it out.
overlap_phases = map(zip(eachcol(states_N1M), eachcol(states_N0))) do (u, v)
    s = sign(dot(u, v))
    iszero(s) ? one(s) : s
end
states_N1M *= Diagonal(overlap_phases);

energies_N0 = energies_N0[1:num_lvls]
energies_N1M = energies_N1M[1:num_lvls]
idx_to_position, position_to_idx = create_idx_pos_dict(potential_array_N0)

# Energy offset between the two manifolds: minus the expectation value of
# (hamiltonian_N0 - hamiltonian_N1M) in the lowest N0 eigenstate.
psi_gs = states_N0[:, 1]
energy_offset = -(psi_gs' * (hamiltonian_N0 - hamiltonian_N1M) * psi_gs)

# First kron idx: ↑, second kron idx: ↓
# in inverse milliseconds (order of dynamics)
#
# The four Hamiltonians share level energies, eigenvectors and energy_offset; they differ
# only in the drive amplitude (Omega_*) and phase (phi_*) arguments. Fresh slices are passed
# to every call, and H_free is timed.
H_free, H_x, H_y, H_drive = let
    build_h(omega, phi) = single_molecule_hamiltonian(
        energies_N0[1:num_lvls],
        energies_N1M[1:num_lvls],
        omega,
        energy_offset,
        phi,
        states_N0[:, 1:num_lvls],
        states_N1M[:, 1:num_lvls],
    )

    h_free = @time build_h(0.0, 0.0)
    h_x = build_h(Omega_pulse, 0.0)
    h_y = build_h(Omega_pulse, phi_echo)
    h_drive = build_h(Omega_drive, phi_drive)
    (h_free, h_x, h_y, h_drive)
end
# Operator basis for the 2num_lvls-dimensional (internal ⊗ motional) space.
GG = ggms(2num_lvls);

# Jump operators on (2 internal levels) ⊗ (num_lvls motional levels):
#  - L_deph: diag(-1/2, 1/2) on the internal state, identity on the motion;
#  - L_deph_motion: identity on the internal state, diag(1, 2, …, num_lvls) on the motion.
# Shifting a Hermitian jump operator by a multiple of the identity leaves the Lindblad
# dissipator unchanged, so starting the motional ladder at 1 instead of 0 is harmless.
sz = [-1 0; 0 1]/2
L_deph = sqrt(2gam_deph) * kron(sz, I(num_lvls))
L_deph_motion = sqrt(2gam_deph_motion) * Matrix(kron(I(2), Diagonal(1:num_lvls)))
println("Computing evolution operators")
@time liouville_free = evolution_superoperator(GG, 2pi*H_free, [L_deph, L_deph_motion]);
@time liouville_x = evolution_superoperator(GG, 2pi*H_x, [L_deph, L_deph_motion]);
@time liouville_y = evolution_superoperator(GG, 2pi*H_y, [L_deph, L_deph_motion]);
@time liouville_drive = evolution_superoperator(GG, 2pi*H_drive, [L_deph, L_deph_motion]);

println("Matrix exponential")
@time evol_x_2 = exp(liouville_x * 1/(4Omega_pulse));
# NOTE(review): evol_x and evol_y are not read anywhere in this file. Each is a dense
# exponential of a (2num_lvls)^2 × (2num_lvls)^2 matrix; drop them unless an including
# script uses them.
evol_x = exp(liouville_x * 1/(2Omega_pulse));
evol_y = exp(liouville_y * 1/(2Omega_pulse));

# Level spacing between the two lowest N0 motional states.
hb_om = energies_N0[2] - energies_N0[1];

# Projectors onto internal index 1 (first num_lvls entries) and its complement.
p0_op = Diagonal([ones(num_lvls); zeros(num_lvls)])
p1_op = I - p0_op
# Expansion coefficients tr(G_k P) in the GG basis. The ggms elements are Hermitian and P is
# real diagonal, so every trace is real; store them as Float64 rather than an abstractly
# typed mix of Float64 and ComplexF64.
p0_op_vec = real.(tr.(GG .* Ref(p0_op)))
p1_op_vec = real.(tr.(GG .* Ref(p1_op)))

# Thermal (Boltzmann) weights over the num_lvls N0 levels, with beta measured in units of
# 1/hb_om.
beta_w_units = beta/hb_om;
partition_sum = sum(exp.(-beta_w_units * energies_N0))
println("Ground state fraction = $(exp(-beta_w_units*energies_N0[1])/partition_sum)")

# The drive propagator is exp(L*tau_arr[1]) followed by repeated exp(L*dtau) steps, which
# is only valid on a uniformly spaced tau grid.
dtau = tau_arr[2] - tau_arr[1]
isapprox(tau_arr, range(tau_arr[1]; step=dtau, length=length(tau_arr))) ||
    error("tau_arr must be uniformly spaced (step $dtau) for the stepped drive propagator")
@time evol_tau1 = exp(liouville_drive * tau_arr[1])
@time evol_dtau = exp(liouville_drive * dtau)

println("Computing evolution")
flush(stdout)
# Wrapped in `let` so the propagated state is an ordinary local (no top-level soft-scope
# ambiguity when it is reassigned inside the loop).
@time let
    # Initial state: diagonal with thermal populations on the first num_lvls basis states.
    # It does not depend on tau, so it is built once.
    rho0 = zeros(ComplexF64, size(H_x)...)
    for ho_level in 1:num_lvls
        rho0[ho_level, ho_level] = 1 / partition_sum * exp(-beta_w_units * energies_N0[ho_level])
    end
    @assert tr(rho0) ≈ 1

    # Coefficients of rho0 in the GG basis, after the first evol_x_2 pulse and the initial
    # tau_arr[1] of drive; both steps are the same for every tau.
    rho_vec = evol_x_2 * tr.(GG .* [rho0])
    rho_vec = evol_tau1 * rho_vec

    for idx_tau in eachindex(tau_arr)
        # Uniform grid (checked where evol_dtau is built): advance the state by one dtau
        # of drive per grid point. This is one matrix-vector product instead of a fresh
        # matrix power per point.
        if idx_tau > 1
            rho_vec = evol_dtau * rho_vec
        end
        rhot = evol_x_2 * rho_vec

        # Imaginary parts are zero up to round-off; take the real part explicitly so the
        # Float64 outputs never depend on them being exactly zero.
        p0_arr[idx_tau] = real(p0_op_vec' * rhot)
        p1_arr[idx_tau] = real(p1_op_vec' * rhot)
    end
end

"""
    sum_chi2(data_df, theory)

Chi-squared of `theory` against the measured `data_df.mean`, using asymmetric error bars.

The residual is `data_df.mean - theory`. A negative residual (theory above the data point)
uses that point's `data_df.errorUpper`; otherwise `data_df.errorLower` is used. Returns
`sum(residual^2 / error^2)` over all points.
"""
function sum_chi2(data_df, theory)
    residuals = data_df.mean - theory
    sigmas = map(eachindex(residuals)) do k
        if residuals[k] < 0
            data_df.errorUpper[k]
        else
            data_df.errorLower[k]
        end
    end
    return sum(residuals .^ 2 ./ sigmas .^ 2)
end

# One row per delay time with the two computed expectation values, written as CSV to
# "$save_dir/drive-$pstring.csv" (save_dir is created if missing).
output_df = DataFrame(; time = tau_arr, p0 = p0_arr, p1 = p1_arr)

mkpath(save_dir)
ofn = "$save_dir/drive-$pstring.csv"
CSV.write(ofn, output_df)
